#!/bin/bash
#SBATCH -A example              # slurm account
#SBATCH -p partition            # slurm partition name
#SBATCH -N 1                    # number of nodes
#SBATCH -t 04:00:00             # wall time
#SBATCH -J "t5x:train"          # slurm job name
#SBATCH --exclusive             # exclusive node access
#SBATCH --mem=0                 # all mem avail
#SBATCH --mail-type=FAIL        # only send email on failure
#SBATCH --overcommit            # overcommit resources (sbatch -O/--overcommit)
#SBATCH --dependency=singleton  # tells slurm to run only one job with the same job name at a time
#
# Slurm batch script that launches T5X pretraining through srun.
# Usage:
#   sbatch <this_script> T5_SIZE PREC GPUS_PER_NODE BSIZE_PER_GPU MODEL_DIR_LOCAL NUM_MICROBATCHES MP
# The account and partition values above are placeholders; adjust them together
# with the "<< CHANGE ! >>" sections below.

# Trace every command so the job log records the expanded paths and the final
# srun invocation.
set -x

# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# File system and volume glue code
#-------------------------------------------------------------------------------
# Host-side locations (BASE_*) are bind-mounted onto fixed in-container
# locations; the in-container locations are also exported to the job so the
# training scripts can find them.

# << CHANGE ! >>
# NOTE(review): these two values are not referenced anywhere else in this
# script; they are kept for reference only.
SLURM_ACCOUNT='example'
USERID='exampleperson'

# << CHANGE ! >>
CONTAINER="" # Add link to your built container

# << CHANGE ! >>
BASE_T5X_DIR="...../t5x_git" # path to your clone of the repo
BASE_TFDS_DATA_DIR=""        # path to tfds data directory
BASE_T5X_WORKSPACE_DIR="${BASE_T5X_DIR}/workspace" # path to where outputs will be dumped

# Default env variables for paths required by t5x training scripts
# (absolute paths inside the container)
TFDS_DATA_DIR=/t5x_home/datasets/
T5X_DIR=/t5x_home/
T5X_WORKSPACE_DIR=/t5x_home/workspace

# Add the T5x/JAX specific mounts, as comma-separated "host:container" pairs.
# The container-side paths are already absolute, so they are used as-is; adding
# another leading '/' would produce targets such as "//t5x_home/" that no longer
# match the values exported below.
MOUNTS="--container-mounts=${BASE_T5X_DIR}:${T5X_DIR},${BASE_TFDS_DATA_DIR}:${TFDS_DATA_DIR},${BASE_T5X_WORKSPACE_DIR}:${T5X_WORKSPACE_DIR}"

# Add T5x/JAX specific exports: keep the submitting environment (ALL) and add
# the in-container paths.
EXPORTS="--export=ALL,TFDS_DATA_DIR=${TFDS_DATA_DIR},T5X_DIR=${T5X_DIR},T5X_WORKSPACE_DIR=${T5X_WORKSPACE_DIR}"
#-------------------------------------------------------------------------------

# Positional parameters, forwarded verbatim to the underlying training script
T5_SIZE="${1}"            # model size: small, base, large, xl, xxl
PREC="${2}"               # precision: bfloat16, float32
GPUS_PER_NODE="${3}"      # GPUs (and srun tasks) per node, usually 8
BSIZE_PER_GPU="${4}"      # local batch size per GPU
MODEL_DIR_LOCAL="${5}"    # directory to save checkpoints and config dump to
NUM_MICROBATCHES="${6}"   # number of gradient accumulation steps
MP="${7}"                 # tensor parallel count

# Total GPUs in the allocation; an empty operand is evaluated as 0 by the
# arithmetic expansion.
NUM_GPUS=$((GPUS_PER_NODE * SLURM_JOB_NUM_NODES))

# << CHANGE ! >>
# Command string handed to `bash -c` by srun: three '&&'-chained steps (start
# banner, nvidia-smi, training launcher), so a failing step stops the chain.
# You can add binding by wrapping the launcher step, i.e. start the last step
# below with:
#   && bash <<path_to_bind_script>>/bind.sh --cpu=exclusive --ib=single -- bash ${T5X_DIR}/...
cmd='echo "*******STARTING********"'
cmd+=' && nvidia-smi'
cmd+=" && bash ${T5X_DIR}/t5x/scripts_gpu/multiprocess_pretrain_pile.sh ${T5_SIZE} ${PREC} ${GPUS_PER_NODE} ${BSIZE_PER_GPU} ${MODEL_DIR_LOCAL} ${NUM_MICROBATCHES} ${MP}"
# Pass the assembled text through `read` so that leading/trailing blanks (e.g.
# from empty trailing arguments) are trimmed; read returns non-zero at
# end-of-input, which is expected here.
read -r -d '' cmd <<<"${cmd}"

# Launch the training command under srun and report its result back to Slurm.

# GPUS_PER_NODE becomes srun's --ntasks-per-node; reject an empty or
# non-numeric value here instead of failing later inside srun.
if ! [[ "${GPUS_PER_NODE}" =~ ^[1-9][0-9]*$ ]]; then
  echo "error: GPUS_PER_NODE (3rd argument) must be a positive integer, got '${GPUS_PER_NODE}'" >&2
  exit 1
fi

# create run specific output directory for ease of analysis
# SL is not assigned anywhere in this script, so the "-sl_<value>" tag is only
# added when the submitting environment provides SL; otherwise it is omitted
# rather than leaving a dangling "-sl_-" in the name.
RUN_NAME="t5_${T5_SIZE}-prec_${PREC}-nodes_${SLURM_JOB_NUM_NODES}-gpus_${NUM_GPUS}-bs_${BSIZE_PER_GPU}${SL:+-sl_${SL}}-mp_${MP}"
RUN_DIR="${BASE_T5X_WORKSPACE_DIR}/outputs/multinode/${RUN_NAME}"
if ! mkdir -p "${RUN_DIR}"; then
  echo "error: cannot create output directory '${RUN_DIR}'" >&2
  exit 1
fi

# redirect both stdout and stderr in the same file for ease of analysis
# (%j and %t are filename patterns expanded by srun, not by the shell)
OUTFILE="${RUN_DIR}/output-%j-%t.txt"

# Log the exact command string; printf with a quoted argument avoids the
# word-splitting and glob expansion an unquoted echo would apply to it.
printf '%s\n' "${cmd}"

srun --ntasks-per-node="${GPUS_PER_NODE}" -o "${OUTFILE}" -e "${OUTFILE}" --container-image="${CONTAINER}" "${MOUNTS}" "${EXPORTS}" bash -c "${cmd}"
srun_status=$?
set +x

# Without this, `set +x` would be the last command and the batch script would
# always exit 0, hiding srun failures from Slurm (and from --mail-type=FAIL).
exit "${srun_status}"

